// Shared prologue for the element-wise device functions in this file.
// Expands to statements that compute the calling thread's flat global index
// (x dimension only) into a local `int idx` and return early when that index
// is outside [0, len). It must be expanded inside a void function that has an
// integer `len` in scope, and is used as a statement:
//   ELEMWISE_BOUNDS_AND_INDEX;
//
// The index is formed and range-checked in unsigned arithmetic before it is
// narrowed to int. The x components of threadIdx/blockIdx/blockDim are
// unsigned, so a global index above INT_MAX would otherwise become a negative
// int, slip past an `idx >= len` test and address memory before the array.
#define ELEMWISE_BOUNDS_AND_INDEX \
  const unsigned int elemwise_gid_ = threadIdx.x + blockIdx.x * blockDim.x; \
  if (len <= 0 || elemwise_gid_ >= static_cast<unsigned int>(len)) \
    return; \
  int idx = static_cast<int>(elemwise_gid_)

// Adds the scalar `scal` to one element of `array`, in place.
// The calling thread handles the single element chosen by
// ELEMWISE_BOUNDS_AND_INDEX; threads whose index is not below `len` return
// without touching memory.
template <typename T>
__device__ void add_scal(T *array, T scal, int len) {
  ELEMWISE_BOUNDS_AND_INDEX;
  T *const slot = array + idx;
  *slot += scal;
}
// Multiplies one element of `array` by the scalar `scal`, in place.
// The calling thread handles the single element chosen by
// ELEMWISE_BOUNDS_AND_INDEX; threads whose index is not below `len` return
// without touching memory.
template <typename T>
__device__ void mul_scal(T *array, T scal, int len) {
  ELEMWISE_BOUNDS_AND_INDEX;
  T *const slot = array + idx;
  *slot *= scal;
}

// Generates a templated device function `elem_<NAME>` that combines two
// arrays element by element, in place: X[i] = X[i] OP Y[i]. Y is only read.
// The calling thread handles the one index produced by
// ELEMWISE_BOUNDS_AND_INDEX and returns early when it is not below `len`.
#define DEF_ELEMWISE_OP(NAME, OP) \
  template <typename T> \
  __device__ void elem_ ## NAME(T *X, T *Y, int len) { \
    ELEMWISE_BOUNDS_AND_INDEX; \
    T *const dst = X + idx; \
    *dst = *dst OP Y[idx]; \
  }

// Instantiate the generator for the three plain arithmetic operators.
// Division is written out by hand in this file because it guards against
// near-zero divisors.
DEF_ELEMWISE_OP(add, +)
DEF_ELEMWISE_OP(sub, -)
DEF_ELEMWISE_OP(mul, *)

// Guarded element-wise division, result stored in X: X[i] = X[i] / Y[i].
// When |Y[i]| is below 1e-20 the divisor is replaced by 1e-20 carrying the
// sign of Y[i] (a value that is not negative, including zero, gives +1e-20),
// which avoids dividing by zero. Y is only read.
// The calling thread handles the one index produced by
// ELEMWISE_BOUNDS_AND_INDEX and returns early when it is not below `len`.
template <typename T>
__device__ void elem_div(T *X, T *Y, int len) {
  ELEMWISE_BOUNDS_AND_INDEX;
  // Keep the guard value in T: a bare 1e-20 literal is a double and would
  // drag the float instantiation's compare and divide into double precision.
  const T eps = static_cast<T>(1e-20);
  // Read the divisor once so the sign test and the division see one value.
  const T y = Y[idx];
  T denom = y;
  if (abs(y) < eps)
    denom = (y < 0) ? -eps : eps;
  X[idx] = X[idx] / denom;
}

// Guarded element-wise division, result stored in Y: Y[i] = X[i] / Y[i].
// When |Y[i]| is below 1e-20 the divisor is replaced by 1e-20 carrying the
// sign of Y[i] (a value that is not negative, including zero, gives +1e-20),
// which avoids dividing by zero. X is only read.
// The calling thread handles the one index produced by
// ELEMWISE_BOUNDS_AND_INDEX and returns early when it is not below `len`.
template <typename T>
__device__ void elem_div2(T *X, T *Y, int len) {
  ELEMWISE_BOUNDS_AND_INDEX;
  // Keep the guard value in T: a bare 1e-20 literal is a double and would
  // drag the float instantiation's compare and divide into double precision.
  const T eps = static_cast<T>(1e-20);
  // Read the divisor once, before it is overwritten with the quotient.
  const T y = Y[idx];
  T denom = y;
  if (abs(y) < eps)
    denom = (y < 0) ? -eps : eps;
  Y[idx] = X[idx] / denom;
}

// Raises one element of X to the power `p`, in place: X[i] = pow(X[i], p).
// The element type (T1) and exponent type (T2) are independent, so the
// `pow` overload is chosen from the pair the caller instantiates.
// The calling thread handles the one index produced by
// ELEMWISE_BOUNDS_AND_INDEX and returns early when it is not below `len`.
template <typename T1, typename T2>
__device__ void elem_pow(T1 *X, T2 p, int len) {
  ELEMWISE_BOUNDS_AND_INDEX;
  T1 *const cell = X + idx;
  *cell = pow(*cell, p);
}

// Clamped natural logarithm, in place: X[i] = log(max(X[i], 1e-20)).
// Any input that does not compare greater than 1e-20 (zero, negatives, and
// NaN, for which the comparison is false) is replaced by 1e-20 before the
// logarithm is taken.
// The calling thread handles the one index produced by
// ELEMWISE_BOUNDS_AND_INDEX and returns early when it is not below `len`.
template <typename T>
__device__ void elem_log(T *X, int len) {
  ELEMWISE_BOUNDS_AND_INDEX;
  // Keep the floor in T: with a bare 1e-20 (a double) the conditional
  // expression becomes double, and the float instantiation would silently
  // evaluate a double-precision log instead of the float overload.
  const T floor_val = static_cast<T>(1e-20);
  const T x = X[idx];
  X[idx] = log(x > floor_val ? x : floor_val);
}

// Exponential of one element, in place: X[i] = exp(X[i]).
// The calling thread handles the one index produced by
// ELEMWISE_BOUNDS_AND_INDEX and returns early when it is not below `len`.
template <typename T>
__device__ void elem_exp(T *X, int len) {
  ELEMWISE_BOUNDS_AND_INDEX;
  T *const cell = X + idx;
  *cell = exp(*cell);
}

// Exponential of one element, in place, computed with the `__expf`
// intrinsic: X[i] = __expf(X[i]). `__expf` is CUDA's fast single-precision
// exponential, which trades some accuracy for speed.
// The calling thread handles the one index produced by
// ELEMWISE_BOUNDS_AND_INDEX and returns early when it is not below `len`.
template <typename T>
__device__ void elem___expf(T *X, int len) {
  ELEMWISE_BOUNDS_AND_INDEX;
  T *const cell = X + idx;
  *cell = __expf(*cell);
}


// Generates the two kernel entry points for the binary element-wise device
// function `elem_<NAME>`: `elem_<NAME>_float` and `elem_<NAME>_double`.
// Each kernel forwards its arguments unchanged to the matching template
// instantiation.
#define DEF_ELEMWISE_API(NAME) \
  __global__ void elem_ ## NAME ## _double(double *X, double *Y, int len) { \
    elem_ ## NAME<double>(X, Y, len); \
  } \
  __global__ void elem_ ## NAME ## _float(float *X, float *Y, int len) { \
    elem_ ## NAME<float>(X, Y, len); \
  }

// Kernel entry points. They are declared with C linkage, so their symbol
// names are not C++-mangled (presumably so host code can look the kernels up
// by their plain names -- confirm against the loader).
// Every kernel forwards to a device template that processes one element per
// thread using only the x dimension of the launch, so a launch needs at
// least `len` threads in total along x to cover the whole array.
extern "C" {

// Binary array ops: elem_<op>_float / elem_<op>_double, X[i] = X[i] op Y[i].
DEF_ELEMWISE_API(add)
DEF_ELEMWISE_API(sub)
DEF_ELEMWISE_API(mul)
// Guarded division: `div` stores the quotient in X, `div2` stores it in Y.
DEF_ELEMWISE_API(div)
DEF_ELEMWISE_API(div2)


// X[i] += Y for a scalar Y.
__global__ void add_scal_double(double *X, double Y, int len) {
  add_scal<double>(X, Y, len);
}
__global__ void add_scal_float(float *X, float Y, int len) {
  add_scal<float>(X, Y, len);
}

// X[i] *= Y for a scalar Y.
__global__ void mul_scal_double(double *X, double Y, int len) {
  mul_scal<double>(X, Y, len);
}
__global__ void mul_scal_float(float *X, float Y, int len) {
  mul_scal<float>(X, Y, len);
}

// X[i] = pow(X[i], p). The two-letter suffix names the element type and the
// exponent type: f = float, d = double, i = int.
__global__ void elem_pow_ff(float *X, float p, int len) {
  elem_pow<float, float>(X, p, len);
}
__global__ void elem_pow_fi(float *X, int p, int len) {
  elem_pow<float, int>(X, p, len);
}
__global__ void elem_pow_dd(double *X, double p, int len) {
  elem_pow<double, double>(X, p, len);
}
__global__ void elem_pow_di(double *X, int p, int len) {
  elem_pow<double, int>(X, p, len);
}

// X[i] = log(X[i]) with the input clamped from below by elem_log.
__global__ void elem_log_float(float *X, int len) {
  elem_log<float>(X, len);
}
__global__ void elem_log_double(double *X, int len) {
  elem_log<double>(X, len);
}

// X[i] = exp(X[i]). The float kernel goes through the __expf intrinsic
// wrapper rather than elem_exp; the double kernel uses exp.
__global__ void elem_exp_float(float *X, int len) {
  elem___expf<float>(X, len);
}
__global__ void elem_exp_double(double *X, int len) {
  elem_exp<double>(X, len);
}


} // extern "C"

// vim: ft=cuda
